import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import torch.nn.functional as F
import torchvision.transforms as T
import os
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score
from scipy.optimize import linear_sum_assignment
import numpy as np



def image_embedding(model, device, image):
    """Capture the token embeddings produced by a CLIP vision embedding module.

    A forward hook on ``model.vision_model.embeddings`` records that module's
    output while ``model.get_image_features`` runs. The projected image
    features returned by that call are discarded.

    Args:
        model: CLIP-style model exposing ``vision_model.embeddings`` and
            ``get_image_features``. It is moved to ``device`` (``nn.Module.to``
            moves parameters in place).
        device: Target device for the model and the image tensor.
        image: Pixel-value tensor accepted by ``get_image_features``
            (presumably ``(batch, 3, H, W)`` — confirm against the caller).

    Returns:
        ``(cls_token, patch_tokens)``: the first token of the embedding output
        (``output[:, 0, :]``) and the remaining tokens (``output[:, 1:, :]``).
        Both are ``None`` if the hooked module never ran.

    NOTE: this runs with autograd enabled (no ``torch.no_grad``), so the
    returned tensors may carry gradient history.
    """
    image = image.to(device)
    model = model.to(device)

    vision_model = model.vision_model

    patch_tokens = None
    cls_token = None

    def hook_fn(module, inputs, output):
        nonlocal patch_tokens, cls_token
        # Assumes output is (batch, tokens, dim) with the first token first.
        cls_token = output[:, 0, :]
        patch_tokens = output[:, 1:, :]

    hook = vision_model.embeddings.register_forward_hook(hook_fn)
    try:
        # Only the hooked embedding output is needed; the return value of this
        # call is intentionally ignored.
        model.get_image_features(pixel_values=image)
    finally:
        # Always detach the hook, even if the forward pass raises. Otherwise it
        # would keep firing on every later forward pass of this model.
        hook.remove()

    return cls_token, patch_tokens

def text_embedding(model, processor, device, text, layer=11):
    """Return the position-0 hidden state from one CLIP text-encoder layer.

    The text is tokenized with ``processor``. The text model runs under
    ``torch.no_grad`` while a forward hook on
    ``model.text_model.encoder.layers[layer]`` records that layer's output.

    Args:
        model: CLIP-style model exposing ``text_model.encoder.layers``. It is
            moved to ``device`` (``nn.Module.to`` moves parameters in place).
        processor: Tokenizer/processor (e.g. ``CLIPProcessor``) whose result
            supports ``.to(device)`` and can be unpacked into
            ``text_model(**inputs)``.
        device: Target device for the model and the tokenized inputs.
        text: A string or a list of strings. Lists are padded to a common
            length so texts of different lengths can be batched.
        layer: Index of the encoder layer to read. Defaults to 11, the value
            this function previously hard-coded.

    Returns:
        Tensor ``hidden[:, 0, :]``, the hidden state at sequence position 0 for
        every text in the batch, or ``None`` if the hooked layer never ran.
        NOTE(review): for CLIP tokenizers position 0 is presumably the
        start-of-text token rather than CLIP's pooled (end-of-text)
        representation — confirm this is the intended embedding.
    """
    # Move the model like the sibling helpers do, so the inputs and the
    # weights end up on the same device.
    model = model.to(device)
    inputs = processor(text=text, padding=True, return_tensors="pt").to(device)

    text_model = model.text_model

    cls_token = None

    def hook_fn(module, inputs, output):
        nonlocal cls_token
        # The layer output is indexed as output[0], i.e. a tuple whose first
        # item is the hidden states. A bare tensor is accepted too, in case
        # the layer returns one directly.
        hidden = output[0] if isinstance(output, (tuple, list)) else output
        cls_token = hidden[:, 0, :]

    hook = text_model.encoder.layers[layer].register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            text_model(**inputs)
    finally:
        # Detach the hook even if the forward pass raises.
        hook.remove()

    return cls_token


def get_cls_attention_map(model, device, image, layer=-1):
    """Return the attention weights from token 0 to all other image tokens.

    A forward hook on ``model.vision_model.encoder.layers[layer].self_attn``
    records the second element of that module's output (the attention
    weights). It is captured while ``get_image_features`` runs under
    ``torch.no_grad`` with ``output_attentions=True``.

    Args:
        model: CLIP-style model exposing ``vision_model.encoder.layers`` and
            ``get_image_features``. It is moved to ``device`` (``nn.Module.to``
            moves parameters in place).
        device: Target device for the model and the image tensor.
        image: Pixel-value tensor accepted by ``get_image_features``.
        layer: Encoder layer index to read. Defaults to -1 (the last layer),
            matching the previous hard-coded behavior.

    Returns:
        ``attention[:, :, 0, 1:]``: the query-0 row of a rank-4 attention
        tensor, excluding token 0 itself. The layout is presumably
        ``(batch, heads, seq, seq)`` — confirm.

    Raises:
        RuntimeError: if the hooked attention module produced no weights.
    """
    image = image.to(device)
    model = model.to(device)

    vision_model = model.vision_model

    attention_map = None

    def hook_fn(module, inputs, output):
        nonlocal attention_map
        # output[1] is expected to hold the attention-weight tensor.
        attention_map = output[1]

    hook = vision_model.encoder.layers[layer].self_attn.register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            model.get_image_features(pixel_values=image, output_attentions=True)
    finally:
        # Detach the hook even if the forward pass raises.
        hook.remove()

    if attention_map is None:
        raise RuntimeError(
            "attention weights were not captured from the hooked self-attention "
            "module; the attention implementation may not return them "
            "(try loading the model with attn_implementation='eager')"
        )

    # Row 0 = attention paid by token 0; columns 1: skip token 0 itself.
    return attention_map[:, :, 0, 1:]

def get_text_sim_map(patch_tokens, text_embs, device='cuda', drop_fraction=0.3):
    """Mark patches whose cosine similarity to a text embedding exceeds a low quantile.

    Cosine similarity is computed between each patch embedding and the text
    embedding. The k-th smallest similarity becomes the threshold, where
    ``k = int(numel * drop_fraction)`` clamped to ``[1, numel]``. Patches
    strictly above the threshold get 1.0 and the rest 0.0. With the default
    ``drop_fraction=0.3`` roughly the top 70% of patches are kept; fewer are
    kept when similarities tie at the threshold.

    Args:
        patch_tokens: Patch embeddings; a leading size-1 dimension is squeezed
            (presumably ``(1, num_patches, dim)`` — confirm against caller).
        text_embs: Text embedding; a leading size-1 dimension is squeezed
            (presumably ``(1, dim)``). Its last dimension must match that of
            ``patch_tokens``.
        device: Device of the returned mask.
        drop_fraction: Fraction of lowest-similarity patches used to choose
            the threshold. Defaults to the previously hard-coded 0.3.

    Returns:
        Float tensor of 0.0/1.0 values with the similarity's broadcast shape
        (``(num_patches,)`` for the shapes above).
    """
    text_vec = text_embs.squeeze(0)
    image_embs = patch_tokens.squeeze(0)

    cosine_sim = F.cosine_similarity(image_embs, text_vec.unsqueeze(0), dim=-1)

    n = cosine_sim.numel()
    if n == 0:
        # Nothing to threshold; torch.kthvalue would fail on empty input.
        return cosine_sim.float().to(device)

    # torch.kthvalue requires 1 <= k <= n. Without the clamp, fewer than four
    # patches gives k == 0 at the default fraction and raises.
    k = min(max(1, int(n * drop_fraction)), n)
    threshold = torch.kthvalue(cosine_sim.reshape(-1), k).values

    mask = (cosine_sim > threshold).float().to(device)
    return mask



